{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch as t\n",
    "from early_stopping import EarlyStopping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ReLU + Sigmoid + Cross Entropy + L2 + early stopping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Fully connected classifier: 28*28 inputs -> 300 ReLU units -> 10 log-sigmoid scores.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Hidden layer first, then the output layer.\n",
    "        self.fc1 = nn.Linear(in_features=28 * 28, out_features=300)\n",
    "        self.fc2 = nn.Linear(in_features=300, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Flatten every sample and return one log-sigmoid score per class.\n",
    "\n",
    "        The first dimension of ``x`` is kept as the batch dimension; all other\n",
    "        dimensions are flattened into one before the linear layers. Every\n",
    "        returned value is ``log(sigmoid(z))`` and therefore never positive.\n",
    "        \"\"\"\n",
    "        # NOTE(review): the training cells pass these scores to\n",
    "        # nn.CrossEntropyLoss, which applies its own log-softmax on top.\n",
    "        batch_size = x.size(0)\n",
    "        flat = x.view(batch_size, -1)\n",
    "        hidden = F.relu(self.fc1(flat))\n",
    "        return F.logsigmoid(self.fc2(hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mnist_loader import load_data_shared, vectorized_result\n",
    "# Each sub-network is given its own seed so that (assuming the seed controls which\n",
    "# samples load_data_shared draws - TODO confirm in mnist_loader) the three training\n",
    "# subsets differ, which is what bagging relies on. With one shared seed and identical\n",
    "# sizes, all three calls would receive exactly the same arguments.\n",
    "MNIST_PATH = \"../mnist.pkl.gz\"\n",
    "training_data1, validation_data1, _ = load_data_shared(filename=MNIST_PATH,\n",
    "                                                       seed=666,\n",
    "                                                       train_size=400,\n",
    "                                                       vali_size=100,\n",
    "                                                       test_size=0)\n",
    "training_data2, validation_data2, _ = load_data_shared(filename=MNIST_PATH,\n",
    "                                                       seed=667,\n",
    "                                                       train_size=400,\n",
    "                                                       vali_size=100,\n",
    "                                                       test_size=0)\n",
    "training_data3, validation_data3, _ = load_data_shared(filename=MNIST_PATH,\n",
    "                                                       seed=668,\n",
    "                                                       train_size=400,\n",
    "                                                       vali_size=100,\n",
    "                                                       test_size=0)\n",
    "# The held-out test split keeps the original seed so it is the same 100 samples as before.\n",
    "_, _, test_data = load_data_shared(filename=MNIST_PATH,\n",
    "                                   seed=666,\n",
    "                                   train_size=0,\n",
    "                                   vali_size=0,\n",
    "                                   test_size=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(data, net, criterion):\n",
    "    \"\"\"Evaluate ``net`` on a whole dataset with a single forward pass.\n",
    "\n",
    "    Args:\n",
    "        data: pair ``(inputs, labels)``; ``data[0]`` must expose ``.shape[0]``\n",
    "            (the sample count) and both items must be accepted by ``t.Tensor``.\n",
    "        net: callable mapping the input tensor to per-class scores.\n",
    "        criterion: loss called as ``criterion(outputs, labels)`` with integer\n",
    "            (``long``) class labels.\n",
    "\n",
    "    Returns:\n",
    "        ``(loss, accuracy)``: the loss exactly as returned by ``criterion`` and\n",
    "        the fraction of samples whose arg-max class equals the label.\n",
    "    \"\"\"\n",
    "    with t.no_grad():  # evaluation only, no autograd graph is needed\n",
    "        inputs = t.Tensor(data[0])\n",
    "        # Convert the labels to long once: the same integer tensor serves the\n",
    "        # loss and the comparison with the (long) arg-max indices, so no\n",
    "        # float/long comparison is involved.\n",
    "        labels = t.Tensor(data[1]).long()\n",
    "\n",
    "        outputs = net(inputs)\n",
    "        _, predicted = t.max(outputs, 1)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        correct = (predicted == labels).sum().item()\n",
    "        accuracy = correct / data[0].shape[0]\n",
    "        return loss, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(net, train_data, vali_data, criterion, optimizer, is_early_stopping,\n",
    "        max_epochs=1000, patience=3):\n",
    "    \"\"\"Train ``net`` with full-batch gradient steps, optionally with early stopping.\n",
    "\n",
    "    Args:\n",
    "        net: model to optimise; it is updated in place and returned.\n",
    "        train_data: pair ``(inputs, labels)`` accepted by ``t.Tensor``.\n",
    "        vali_data: validation pair passed to ``predict``; only evaluated when\n",
    "            ``is_early_stopping`` is true.\n",
    "        criterion: loss called as ``criterion(outputs, labels)`` with ``long`` labels.\n",
    "        optimizer: optimiser already bound to the parameters of ``net``.\n",
    "        is_early_stopping: when true, the validation loss is handed to an\n",
    "            ``EarlyStopping`` object after every epoch and training stops as\n",
    "            soon as that object sets ``early_stop``.\n",
    "        max_epochs: upper bound on the number of epochs (default 1000).\n",
    "        patience: ``patience`` value forwarded to ``EarlyStopping`` (default 3).\n",
    "\n",
    "    Returns:\n",
    "        The same ``net`` object after training.\n",
    "    \"\"\"\n",
    "    # The whole training set is used as one batch, so these tensors do not\n",
    "    # change between epochs and are built once.\n",
    "    inputs = t.Tensor(train_data[0])\n",
    "    labels = t.Tensor(train_data[1]).long()\n",
    "\n",
    "    # The early-stopping helper is only needed when the feature is switched on.\n",
    "    early_stopping = EarlyStopping(patience=patience, verbose=False) if is_early_stopping else None\n",
    "\n",
    "    for epoch in range(max_epochs):\n",
    "        # zero the parameter gradients, then forward + backward + optimize\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if early_stopping is not None:\n",
    "            # Validation loss after this epoch's update drives the stop decision.\n",
    "            vali_loss, _ = predict(vali_data, net, criterion)\n",
    "            early_stopping(vali_loss, net)\n",
    "\n",
    "            if early_stopping.early_stop:\n",
    "                print(\"Early stopping at \", epoch)\n",
    "                break\n",
    "\n",
    "    print('Finished Training')\n",
    "    return net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  子网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 3 out of 3\n",
      "Early stopping at  696\n",
      "Finished Training\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0.4662), 0.84)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch.optim as optim\n",
    "# Sub-network 1. Seed the torch RNG before building the model so that the\n",
    "# weight initialisation is the same on every Restart & Run All.\n",
    "t.manual_seed(1)\n",
    "net1 = Net()\n",
    "criterion1 = nn.CrossEntropyLoss()\n",
    "# weight_decay is the L2 penalty named in the section title.\n",
    "optimizer1 = optim.SGD(net1.parameters(), lr=1e-1, weight_decay=1e-2)\n",
    "net1 = fit(net1, training_data1, validation_data1, criterion1, optimizer1, True)\n",
    "# Cell output: (loss, accuracy) of this sub-network on the test split.\n",
    "predict(test_data, net1, criterion1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 3 out of 3\n",
      "Early stopping at  711\n",
      "Finished Training\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0.4531), 0.85)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sub-network 2. Seed the torch RNG before building the model so that the\n",
    "# weight initialisation is the same on every Restart & Run All.\n",
    "t.manual_seed(2)\n",
    "net2 = Net()\n",
    "criterion2 = nn.CrossEntropyLoss()\n",
    "# weight_decay is the L2 penalty named in the section title.\n",
    "optimizer2 = optim.SGD(net2.parameters(), lr=1e-1, weight_decay=1e-2)\n",
    "net2 = fit(net2, training_data2, validation_data2, criterion2, optimizer2, True)\n",
    "# Cell output: (loss, accuracy) of this sub-network on the test split.\n",
    "predict(test_data, net2, criterion2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 1 out of 3\n",
      "EarlyStopping counter: 2 out of 3\n",
      "EarlyStopping counter: 3 out of 3\n",
      "Early stopping at  716\n",
      "Finished Training\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0.4567), 0.84)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sub-network 3. Seed the torch RNG before building the model so that the\n",
    "# weight initialisation is the same on every Restart & Run All.\n",
    "t.manual_seed(3)\n",
    "net3 = Net()\n",
    "criterion3 = nn.CrossEntropyLoss()\n",
    "# weight_decay is the L2 penalty named in the section title.\n",
    "optimizer3 = optim.SGD(net3.parameters(), lr=1e-1, weight_decay=1e-2)\n",
    "net3 = fit(net3, training_data3, validation_data3, criterion3, optimizer3, True)\n",
    "# Cell output: (loss, accuracy) of this sub-network on the test split.\n",
    "predict(test_data, net3, criterion3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bagging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def predict2(data, nets):\n",
    "    \"\"\"Accuracy of the ensemble obtained by summing the outputs of ``nets``.\n",
    "\n",
    "    Args:\n",
    "        data: pair ``(inputs, labels)``; ``data[0]`` must expose ``.shape[0]``\n",
    "            and both items must be accepted by ``t.Tensor``.\n",
    "        nets: non-empty iterable of callables, each mapping the input tensor\n",
    "            to per-class scores of the same shape.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of samples whose arg-max over the summed scores equals the label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``nets`` is empty.\n",
    "    \"\"\"\n",
    "    with t.no_grad():\n",
    "        inputs = t.Tensor(data[0])\n",
    "        # Integer labels, so the comparison with arg-max indices is long vs long.\n",
    "        labels = t.Tensor(data[1]).long()\n",
    "\n",
    "        # Start the running sum from the first model's output instead of a\n",
    "        # pre-allocated (n, 10) buffer, so the class count is not hard-coded.\n",
    "        output = None\n",
    "        for net in nets:\n",
    "            scores = net(inputs)\n",
    "            output = scores if output is None else output + scores\n",
    "        if output is None:\n",
    "            raise ValueError(\"predict2 needs at least one network\")\n",
    "\n",
    "        _, predicted = t.max(output, 1)\n",
    "        correct = (predicted == labels).sum().item()\n",
    "        accuracy = correct / data[0].shape[0]\n",
    "        return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.84"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict2(test_data, [net1, net2, net3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "集成（Bagging）后的准确率并没有高于单个子网络，甚至低于其中最好的一个，这是为什么？  \n",
    "难道是子网络之间的相关性太大？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
